"""
Tests for encryption key passing to background threads.

Verifies that research threads can access encrypted databases with the correct password.
"""

import threading
import pytest
from unittest.mock import patch

from tests.test_utils import add_src_to_path

add_src_to_path()


class TestThreadContextPasswordStorage:
    """Test that thread context correctly stores and retrieves passwords."""

    def test_set_and_get_password_in_same_thread(self):
        """Password set via set_search_context should be retrievable."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
            get_search_context,
        )

        set_search_context(
            {
                "username": "test_user",
                "user_password": "secret123",
            }
        )

        ctx = get_search_context()
        assert ctx is not None
        assert ctx.get("user_password") == "secret123"

    def test_context_includes_all_fields(self):
        """All fields in context should be preserved."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
            get_search_context,
        )

        context = {
            "research_id": "res_123",
            "username": "alice",
            "user_password": "pass456",
            "custom_field": "custom_value",
        }
        set_search_context(context)

        retrieved = get_search_context()
        assert retrieved["research_id"] == "res_123"
        assert retrieved["username"] == "alice"
        assert retrieved["user_password"] == "pass456"
        assert retrieved["custom_field"] == "custom_value"


class TestThreadContextIsolation:
    """Test that thread context is properly isolated between threads."""

    def test_child_thread_does_not_inherit_context(self):
        """A new thread should NOT see the parent thread's context."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
            get_search_context,
        )

        child_result = []

        def child_thread():
            ctx = get_search_context()
            child_result.append(ctx)

        # Set context in main thread
        set_search_context({"user_password": "main_thread_pass"})

        # Child thread should not see it
        thread = threading.Thread(target=child_thread)
        thread.start()
        thread.join(timeout=2)

        assert child_result[0] is None, (
            "Child thread should not inherit parent's context"
        )

    def test_child_thread_can_set_own_context(self):
        """A child thread can set and retrieve its own context."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
            get_search_context,
        )

        child_result = []

        def child_thread():
            set_search_context({"user_password": "child_pass"})
            ctx = get_search_context()
            child_result.append(ctx)

        thread = threading.Thread(target=child_thread)
        thread.start()
        thread.join(timeout=2)

        assert child_result[0] is not None
        assert child_result[0]["user_password"] == "child_pass"


class TestGetUserDbSessionPasswordRetrieval:
    """Test that get_user_db_session retrieves password from thread context."""

    def test_password_retrieved_from_thread_context(self):
        """get_user_db_session should use password from thread context."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
        )
        from src.local_deep_research.database.session_context import (
            get_user_db_session,
        )
        from src.local_deep_research.database.encrypted_db import db_manager
        from src.local_deep_research.database import thread_local_session

        set_search_context(
            {
                "username": "test_user",
                "user_password": "thread_context_password",
            }
        )

        captured_passwords = []

        def capture(username, password):
            captured_passwords.append(password)
            raise Exception("Captured")

        with patch(
            "src.local_deep_research.database.session_context.has_app_context",
            return_value=False,
        ):
            with patch.object(db_manager, "has_encryption", True):
                with patch.object(
                    thread_local_session,
                    "get_metrics_session",
                    side_effect=capture,
                ):
                    try:
                        with get_user_db_session("test_user"):
                            pass
                    except Exception:
                        pass

        assert len(captured_passwords) == 1
        assert captured_passwords[0] == "thread_context_password"

    def test_none_password_causes_error_with_encryption(self):
        """If password is None and encryption is enabled, should raise error."""
        from src.local_deep_research.utilities.thread_context import (
            set_search_context,
        )
        from src.local_deep_research.database.session_context import (
            get_user_db_session,
            DatabaseSessionError,
        )
        from src.local_deep_research.database.encrypted_db import db_manager

        # Set context with None password
        set_search_context(
            {
                "username": "test_user",
                "user_password": None,
            }
        )

        with patch(
            "src.local_deep_research.database.session_context.has_app_context",
            return_value=False,
        ):
            with patch.object(db_manager, "has_encryption", True):
                with pytest.raises(DatabaseSessionError) as exc_info:
                    with get_user_db_session("test_user"):
                        pass

                assert "requires password" in str(exc_info.value).lower()
